/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved.
 * Description: op_ir_infer_util - verification helpers shared by operator IR infer-shape functions
 */

#ifndef FRAMEWORK_GRAPH_INFERSHAPE_OP_IR_INFER_UTIL_H
#define FRAMEWORK_GRAPH_INFERSHAPE_OP_IR_INFER_UTIL_H

#include <cmath>
#include <string>
#include <vector>
#include <set>
#include <cfloat>

// inc/api
#include "graph/types.h"
#include "graph/debug/ge_error_codes.h"

// inc/framework
#include "framework/graph/core/infershape/op_ir_ctx.h"

namespace ge {
// Floating-point comparison helpers using an ABSOLUTE tolerance of FLT_EPSILON:
// two values whose difference is within +/-FLT_EPSILON are treated as equal.
// Because the tolerance is absolute (not scaled by magnitude), these behave like
// exact comparisons for large values and treat all values within ~1.19e-7 of each
// other as equal near zero.
// Note: #define is not scoped by the enclosing `namespace ge`; these macros are
// visible to every translation unit that includes this header.
// std::fabs (rather than the unqualified C `fabs`) is used because <cmath> only
// guarantees the std:: name, and it selects the float/long double overloads
// instead of always converting to double.
#define FLOAT_GREATER_EQUAL(a, b) (((a) - (b)) >= -FLT_EPSILON)
#define FLOAT_GREATER(a, b) (((a) - (b)) > FLT_EPSILON)
#define FLOAT_LESS_EQUAL(a, b) (((a) - (b)) <= FLT_EPSILON)
#define FLOAT_LESS(a, b) (((a) - (b)) < -FLT_EPSILON)
#define FLOAT_EQUAL(a, b) (std::fabs((a) - (b)) <= (FLT_EPSILON))
#define FLOAT_NOT_EQUAL(a, b) (std::fabs((a) - (b)) > (FLT_EPSILON))
// InferUtil: a stateless collection of static helpers used by operator IR
// infer-shape functions. Every member is static and reports its outcome as a
// GraphErrCodeStatus. Only the members marked GRAPH_API_EXPORT carry the export
// macro; the others are declared without it.
class InferUtil {
public:
    // ---- Input count checks ----

    // Verifies that the number of inputs equals a fixed value (expectSize).
    GRAPH_API_EXPORT static GraphErrCodeStatus VerifyInputSize(InferContext& context, size_t expectSize);

    // Verifies that the number of inputs lies between minSize and maxSize
    // (bound inclusivity is defined by the implementation - confirm there).
    static GraphErrCodeStatus VerifyInputSizeRange(InferContext& context, size_t minSize, size_t maxSize);

    // ---- Data type checks ----

    // For operators whose inputs all share one DataType: checks the data type of
    // every input against supportDataType in a single call.
    GRAPH_API_EXPORT static GraphErrCodeStatus VerifyInputDataType(
        InferContext& context, const std::set<DataType>& supportDataType);

    // Checks the data type of the input at inputIndex against the supported set.
    GRAPH_API_EXPORT static GraphErrCodeStatus VerifyInputDataType(
        InferContext& context, uint32_t inputIndex, const std::set<DataType>& supportDataType);

    // Single-type variant: the input at index must have exactly supportDataType
    // (presumably - confirm in the implementation).
    static GraphErrCodeStatus VerifyInputDataType(InferContext& context, uint32_t index, DataType supportDataType);

    // ---- Axis checks ----

    // Verifies that axis is valid for the input at inputIndex (presumably against
    // that input's dimension count - confirm in the implementation).
    static GraphErrCodeStatus VerifyAxis(InferContext& context, int64_t axis, uint32_t inputIndex);

    // Same kind of check applied to every axis in axisList (confirm whether
    // duplicates are also rejected).
    static GraphErrCodeStatus VerifyAxisList(InferContext& context, std::vector<int64_t> axisList, uint32_t inputIndex);

    // ---- Pairwise input checks ----

    // Verifies that two inputs have the same data type / the same number of dims.
    static GraphErrCodeStatus VerifySameDataType(InferContext& context, uint32_t inputIndex1, uint32_t inputIndex2);
    static GraphErrCodeStatus VerifySameDimNum(InferContext& context, uint32_t inputIndex1, uint32_t inputIndex2);

    // ---- Dimension-count checks on the input at inputIndex ----

    static GraphErrCodeStatus VerifyDimNumEqualTo(InferContext& context, uint32_t inputIndex, size_t value);
    static GraphErrCodeStatus VerifyDimNumInRange(
        InferContext& context, uint32_t inputIndex, size_t minValue, size_t maxValue);

    static GraphErrCodeStatus VerifyDimNumGreaterOrEqual(InferContext& context, uint32_t inputIndex, size_t value);

    static GraphErrCodeStatus VerifyDimNumLessOrEqual(InferContext& context, uint32_t inputIndex, size_t value);
    // Verifies that the input at inputIndex is not a scalar.
    static GraphErrCodeStatus VerifyNonScalarInput(InferContext& context, uint32_t inputIndex);
    // Verifies that the input at inputIndex is produced by a Const op.
    static GraphErrCodeStatus VerifyConstInput(InferContext& context, uint32_t inputIndex);
    // Verifies that shape1 and shape2 are identical.
    static GraphErrCodeStatus VerifySameShape(InferContext& context, std::vector<int64_t> shape1,
        std::vector<int64_t> shape2);

    // ---- Attribute checks ----
    // In the checks below, `value` (or strValue/intValue/length) is the attribute
    // value already read by the caller, and attrName names that attribute
    // (presumably used for error reporting - confirm in the implementation).

    // value must lie between expectMinValue and expectMaxValue (inclusivity per
    // implementation).
    static GraphErrCodeStatus VerifyAttrRange(
        InferContext& context, const string& attrName, int64_t value, int64_t expectMinValue, int64_t expectMaxValue);

    static GraphErrCodeStatus VerifyAttrFloatRange(
        InferContext& context, const string& attrName, float value, float expectMinValue, float expectMaxValue);

    static GraphErrCodeStatus VerifyAttrEqualTo(
        InferContext& context, const string& attrName, int64_t value, int64_t dstValue);

    static GraphErrCodeStatus VerifyAttrBoolEqualTo(
        InferContext& context, const string& attrName, bool value, bool dstValue);

    static GraphErrCodeStatus VerifyAttrGreaterThan(
        InferContext& context, const string& attrName, int64_t value, int64_t dstValue);

    static GraphErrCodeStatus VerifyAttrFloatGreaterThan(
        InferContext& context, const string& attrName, float value, float dstValue);

    static GraphErrCodeStatus VerifyAttrGreaterOrEqual(
        InferContext& context, const string& attrName, int64_t value, int64_t dstValue);

    static GraphErrCodeStatus VerifyAttrFloatGreaterOrEqual(
        InferContext& context, const string& attrName, float value, float dstValue);

    static GraphErrCodeStatus VerifyAttrLessThan(
        InferContext& context, const string& attrName, int64_t value, int64_t dstValue);

    static GraphErrCodeStatus VerifyAttrLessOrEqual(
        InferContext& context, const string& attrName, int64_t value, int64_t dstValue);

    // length is the actual length of the list attribute; expectLength the required one.
    static GraphErrCodeStatus VerifyAttrListLengthEqualTo(
        InferContext& context, const string& attrName, size_t length, size_t expectLength);

    // Verifies that a required attribute is present.
    static GraphErrCodeStatus VerifyRequiredAttr(InferContext& context, const string& attrName);

    // Verifies that a string attribute is one of the allowed values in expectStrArr.
    static GraphErrCodeStatus VerifyStrAttrInRange(
        InferContext& context, const string& attrName, string strValue, const std::vector<string>& expectStrArr);

    // Verifies that an integer attribute is one of the allowed values in expectIntArr.
    static GraphErrCodeStatus VerifyIntAttrInRange(
        InferContext& context, const string& attrName, int64_t intValue, const std::vector<int64_t>& expectIntArr);

    // Verifies that a string attribute equals the single expected value expectStr.
    static GraphErrCodeStatus VerifyStrAttrEqualTo(
        InferContext& context, const string& attrName, string strValue, const string& expectStr);

    // ---- 1-D Convolution / PoolingD / ConvTranspose helpers ----

    // Pads attribute values (srcValues, modified in place) with insertValue for the 1-D case.
    static GraphErrCodeStatus PaddingAttrs(std::vector<int64_t>& srcValues, int64_t insertValue);

    // Pads the dimensions of shape (modified in place) for the 1-D case, according to dataFormat.
    static GraphErrCodeStatus PaddingInputs(Shape& shape, const std::string dataFormat);

    // ---- Subgraph shape inference ----

    // Runs shape inference over subGraph: inputTensorsDesc describes its inputs and
    // outputTensorsDesc receives its outputs (presumably via the two helpers
    // below - confirm in the implementation).
    GRAPH_API_EXPORT static GraphErrCodeStatus InferSubGraphShape(const ge::ComputeGraphPtr& subGraph,
        InferContext& inferContext, std::vector<ge::TensorDesc>& inputTensorsDesc,
        std::vector<ge::TensorDesc>& outputTensorsDesc);
    // Applies inputTensorsDesc to the inputs of subGraph.
    GRAPH_API_EXPORT static GraphErrCodeStatus UpdateSubGraphInputs(
        const ge::ComputeGraphPtr& subGraph, std::vector<ge::TensorDesc>& inputTensorsDesc);
    // Collects the output tensor descriptions of subGraph into outputTensorsDesc.
    GRAPH_API_EXPORT static GraphErrCodeStatus GetSubGraphOutputs(
        const ge::ComputeGraphPtr& subGraph, std::vector<ge::TensorDesc>& outputTensorsDesc);
};

// Canned input verifiers shared by several operator infer functions. Each one
// inspects the inputs recorded in `context` and reports a GraphErrCodeStatus.
// The accepted input counts and data types are inferred from the names only.
// Confirm them in the implementation before relying on a specific set.

// One input, float data type.
GraphErrCodeStatus OneFloatInputVerify(InferContext& context);

// Two inputs, float data types.
GraphErrCodeStatus TwoFloatInputsVerify(InferContext& context);

// Two inputs, float or integer data types.
GraphErrCodeStatus TwoFloatIntInputsVerify(InferContext& context);

// Two inputs, float and integer data types (broader set than the above - confirm).
GraphErrCodeStatus TwoFloatAndIntsInputsVerify(InferContext& context);

// Two inputs, bool data type.
GraphErrCodeStatus TwoBoolInputsVerify(InferContext& context);
}; // namespace ge

#endif // FRAMEWORK_GRAPH_INFERSHAPE_OP_IR_INFER_UTIL_H